import numpy as np
import tensorflow as tf
from tensorflow.python.ops import tensor_array_ops, control_flow_ops


class Generator(object):
    def __init__(self, num_vocabulary, batch_size, emb_dim, hidden_dim,
                 sequence_length, start_token,
                 discriminator=None, g_embeddings=None,
                 learning_rate=0.01, reward_gamma=0.95):
        """Build the whole generator graph in the default TensorFlow graph.

        Three sub-graphs are created, all sharing the same LSTM and output
        projection:

        * a free-running sampler (``self.gen_x`` / ``self.gen_ot``),
        * a teacher-forced pretraining path (``self.g_predictions``,
          ``self.pretrain_loss``, ``self.pretrain_updates``),
        * an MMD objective on discriminator features (``self.mmd``,
          ``self.g_updates``).

        Args:
            num_vocabulary: vocabulary size; width of the output logits.
            batch_size: fixed batch size baked into every placeholder shape.
            emb_dim: width of a token embedding.
            hidden_dim: width of the LSTM hidden/cell state.
            sequence_length: number of decoding steps per sequence.
            start_token: integer token id fed at step 0; it is replicated
                ``batch_size`` times.
            discriminator: object exposing ``feature(input_x=..., name=...)``.
                NOTE(review): despite the ``None`` default it is called
                unconditionally at the end of this constructor.
            g_embeddings: embedding matrix used for lookups and appended to
                ``g_params``; presumably a [num_vocabulary, emb_dim] variable
                owned by the caller -- TODO confirm. The ``None`` default is
                not handled here.
            learning_rate: initial learning rate, stored in a non-trainable
                variable and handed to both optimizers.
            reward_gamma: stored on the instance; not used by this
                constructor.
        """
        self.num_vocabulary = num_vocabulary
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        # One start token per batch row, used as the step-0 input of both loops.
        self.start_token = tf.constant([start_token] * self.batch_size, dtype=tf.int32)
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.reward_gamma = reward_gamma
        self.g_params = []
        self.d_params = []
        self.discriminator = discriminator
        self.temperature = 1.0
        # Global-norm threshold applied to both gradient lists below.
        self.grad_clip = 5.0
        # Not referenced again in this constructor.
        self.expected_reward = tf.Variable(tf.zeros([self.sequence_length]))

        with tf.variable_scope('generator'):
            # The embedding matrix is supplied by the caller instead of being created here:
            # self.g_embeddings = tf.Variable(self.init_matrix([self.num_vocabulary, self.emb_dim]))
            self.g_embeddings = g_embeddings
            # g_params order: embeddings, then the LSTM variables, then the output projection.
            self.g_params.append(self.g_embeddings)
            self.g_recurrent_unit = self.create_recurrent_unit(self.g_params)  # maps h_tm1 to h_t for generator
            self.g_output_unit = self.create_output_unit(self.g_params)  # maps h_t to o_t (output token logits)

        # placeholder definition

        self.x = tf.placeholder(tf.int32, shape=[self.batch_size,
                                                 self.sequence_length])  # token ids fed to the teacher-forced (pretraining / NLL) path
        self.y = tf.placeholder(tf.int32, shape=[self.batch_size,
                                                 self.sequence_length])  # sequence of tokens of real data


        # processed for batch
        with tf.device("/cpu:0"):
            self.processed_x = tf.transpose(tf.nn.embedding_lookup(self.g_embeddings, self.x),
                                            perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        # Initial states
        # NOTE(review): these are shaped with emb_dim but are used as the LSTM
        # hidden/cell state, which create_recurrent_unit multiplies by
        # [hidden_dim, hidden_dim] matrices -- so the graph only builds when
        # emb_dim == hidden_dim.
        self.h_0 = tf.placeholder(tf.float32, shape=[batch_size, emb_dim])
        self.c_0 = tf.placeholder(tf.float32, shape=[batch_size, emb_dim])
        # Stacked (hidden, cell) pair, the state format the recurrent unit unstacks.
        self.h0 = tf.stack([self.h_0, self.c_0])

        # Per-step accumulators for the free-running sampler:
        # gen_o  -- probability of the chosen token,
        # gen_x  -- chosen token id,
        # gen_ot -- soft embedding fed to the next step.
        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)
        gen_ot = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)

        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_ot):
            """One free-running decoding step (greedy token, soft next input)."""
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            # Greedy choice: the reported token is the argmax of the logits.
            next_token = tf.cast(tf.argmax(o_t, axis=1), tf.int32)
            # Scaling the logits by 1e3 sharpens the softmax towards a one-hot at
            # the argmax, so the next input approximates that token's embedding
            # while staying a differentiable function of the logits.
            x_tp1 = tf.matmul(tf.nn.softmax(tf.multiply(o_t, 1e3)), self.g_embeddings)
            gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_vocabulary, 1.0, 0.0),
                                                             tf.nn.softmax(o_t)), 1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            gen_ot = gen_ot.write(i, x_tp1)
            return i + 1, x_tp1, h_t, gen_o, gen_x, gen_ot

        # Run sequence_length steps starting from the start-token embedding and h0.
        # self.gen_o is left as a TensorArray; only gen_x and gen_ot are stacked below.
        _, _, _, self.gen_o, self.gen_x, self.gen_ot = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o, gen_x, gen_ot))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        self.gen_x = tf.transpose(self.gen_x, perm=[1, 0])  # batch_size x seq_length

        self.gen_ot = self.gen_ot.stack()
        # Slice starting at the origin with the full expected extent; presumably
        # here to pin the static shape -- TODO confirm.
        self.gen_ot = tf.slice(self.gen_ot,begin=[0,0,0],size=[sequence_length, batch_size, emb_dim])
        self.gen_ot = tf.transpose(self.gen_ot, perm=[1, 0, 2])  # batch_size x seq_length x g_emb_dim

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length,
            dynamic_size=False, infer_shape=True)

        # Time-major embeddings of self.x, one TensorArray element per step.
        ta_emb_x = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            """One teacher-forced step: predict token i, then feed the true token i."""
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            g_predictions = g_predictions.write(i, tf.nn.softmax(o_t))  # batch x vocab_size
            # Teacher forcing: the next input is the embedding of the given token
            # at position i, not the model's own prediction.
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token),
                       self.h0, g_predictions))

        self.g_predictions = tf.transpose(self.g_predictions.stack(),
                                          perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        # pretraining loss: negative log-likelihood of self.x under g_predictions,
        # averaged over every token of the batch. Probabilities are clipped to
        # [1e-20, 1.0] before the log so a zero probability cannot yield -inf.
        self.pretrain_loss = -tf.reduce_sum(
            tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_vocabulary, 1.0, 0.0) * tf.log(
                tf.clip_by_value(tf.reshape(self.g_predictions, [-1, self.num_vocabulary]), 1e-20, 1.0)
            )
        ) / (self.sequence_length * self.batch_size)

        # training updates
        pretrain_opt = self.g_optimizer(self.learning_rate)

        # Gradients w.r.t. every generator parameter, clipped by global norm.
        self.pretrain_grad, _ = tf.clip_by_global_norm(tf.gradients(self.pretrain_loss, self.g_params), self.grad_clip)
        self.pretrain_updates = pretrain_opt.apply_gradients(zip(self.pretrain_grad, self.g_params))

        #######################################################################################################
        #  Unsupervised Training
        #######################################################################################################

        def get_feature(input_x, name=''):
            """Delegate feature extraction to the discriminator's ``feature`` method."""
            return self.discriminator.feature(input_x=input_x, name=name)

        def compute_pairwise_distances(x, y):
            """Computes the squared pairwise Euclidean distances between x and y.
            Args:
              x: a tensor of shape [num_x_samples, num_features]
              y: a tensor of shape [num_y_samples, num_features]
            Returns:
              a distance matrix of dimensions [num_x_samples, num_y_samples].
            Raises:
              ValueError: if the inputs do not match the specified dimensions.
            """

            if not len(x.get_shape()) == len(y.get_shape()) == 2:
                raise ValueError('Both inputs should be matrices.')

            if x.get_shape().as_list()[1] != y.get_shape().as_list()[1]:
                raise ValueError('The number of features should be the same.')

            # Sum of squares over axis 1 (the feature axis of the broadcast difference).
            norm = lambda x: tf.reduce_sum(tf.square(x), 1)

            # By making the `inner' dimensions of the two matrices equal to 1 using
            # broadcasting then we are essentially subtracting every pair of rows
            # of x and y.
            # x will be num_samples x num_features x 1,
            # and y will be 1 x num_features x num_samples (after broadcasting).
            # After the subtraction we will get a
            # num_x_samples x num_features x num_y_samples matrix.
            # The resulting dist will be of shape num_y_samples x num_x_samples.
            # and thus we need to transpose it again.
            return tf.transpose(norm(tf.expand_dims(x, 2) - tf.transpose(y)))

        def gaussian_kernel_matrix(x, y, sigmas=None):
            r"""Computes a Gaussian Radial Basis Kernel between the samples of x and y.
            We create a sum of multiple gaussian kernels each having a width sigma_i.
            Args:
              x: a tensor of shape [num_samples, num_features]
              y: a tensor of shape [num_samples, num_features]
              sigmas: a tensor of floats which denote the widths of each of the
                gaussians in the kernel. Defaults to a fixed list of 19 widths
                ranging from 1e-6 to 1e6.
            Returns:
              A tensor of shape [num_samples{x}, num_samples{y}] with the RBF kernel.
            """
            if sigmas is None:
                sigmas = [
                    1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
                    1e3, 1e4, 1e5, 1e6
                ]
            # Column vector of 1 / (2 * sigma_i), one row per kernel width.
            beta = 1. / (2. * (tf.expand_dims(sigmas, 1)))

            dist = compute_pairwise_distances(x, y)

            # Outer product: row i holds beta_i times every flattened pairwise distance.
            s = tf.matmul(beta, tf.reshape(dist, (1, -1)))

            # Sum exp(-beta_i * dist) over the widths, then restore the matrix shape.
            return tf.reshape(tf.reduce_sum(tf.exp(-s), 0), tf.shape(dist))

        def calc_mmd(x, y):
            """Return mean k(x, x) + mean k(y, y) - 2 * mean k(x, y), floored at zero."""
            cost = tf.reduce_mean(gaussian_kernel_matrix(x, x))
            cost += tf.reduce_mean(gaussian_kernel_matrix(y, y))
            cost -= 2 * tf.reduce_mean(gaussian_kernel_matrix(x, y))

            # We do not allow the loss to become negative.
            cost = tf.where(cost > 0, cost, 0, name='value')

            return cost


        # Generated side uses the soft embeddings (float); real side uses the
        # token-id placeholder (int32). The discriminator's ``feature`` method is
        # assumed to accept both kinds of input -- TODO confirm.
        x_feature = get_feature(input_x=self.gen_ot, name='gx')
        y_feature = get_feature(input_x=self.y, name='gy')
        self.mmd = calc_mmd(x_feature, y_feature)
        # Separate optimizer instance from the pretraining one.
        g_opt = self.g_optimizer(self.learning_rate)

        self.g_grad, _ = tf.clip_by_global_norm(tf.gradients(self.mmd, self.g_params), self.grad_clip)
        self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params))

    def generate(self, sess, get_z=False):
        """Sample one batch of token sequences from the generator.

        The initial hidden state is uniform noise drawn from [-0.01, 1) and the
        initial cell state is all zeros; both are fed as [batch_size, emb_dim]
        arrays.

        Args:
            sess: session-like object whose ``run(fetches, feed)`` evaluates
                ``self.gen_x``.
            get_z: when true, also return the noise used as the initial hidden
                state.

        Returns:
            The result of running ``self.gen_x``, or a ``(result, noise)`` tuple
            when ``get_z`` is true.
        """
        state_shape = [self.batch_size, self.emb_dim]
        noise_h = np.random.uniform(low=-.01, high=1, size=state_shape)
        zero_c = np.zeros(shape=state_shape)

        samples = sess.run(self.gen_x, {self.h_0: noise_h, self.c_0: zero_c})

        if get_z:
            return samples, noise_h
        return samples

    def get_nll(self, sess, batch):
        """Evaluate the pretraining loss on ``batch`` without updating weights.

        The initial hidden state is uniform noise drawn from [-0.01, 0.01) and
        the initial cell state is all zeros.

        Args:
            sess: session-like object whose ``run(fetches, feed)`` evaluates
                ``self.pretrain_loss``.
            batch: token ids fed to the ``self.x`` placeholder.

        Returns:
            Whatever ``sess.run`` returns for ``self.pretrain_loss``.
        """
        state_shape = [self.batch_size, self.emb_dim]
        noise_h = np.random.uniform(low=-.01, high=.01, size=state_shape)
        zero_c = np.zeros(shape=state_shape)

        inputs = {self.h_0: noise_h, self.c_0: zero_c, self.x: batch}
        return sess.run(self.pretrain_loss, inputs)

    def pretrain_step(self, sess, x):
        """Run one supervised pretraining update on the batch ``x``.

        Both initial LSTM states are fed as all-zero [batch_size, emb_dim]
        arrays.

        Args:
            sess: session-like object whose ``run(fetches, feed_dict=...)``
                evaluates the update op and the loss.
            x: token ids fed to the ``self.x`` placeholder.

        Returns:
            The result of running ``[self.pretrain_updates, self.pretrain_loss]``.
        """
        state_shape = [self.batch_size, self.emb_dim]
        zero_h = np.zeros(shape=state_shape)
        zero_c = np.zeros(shape=state_shape)

        fetches = [self.pretrain_updates, self.pretrain_loss]
        return sess.run(fetches, feed_dict={self.h_0: zero_h, self.c_0: zero_c, self.x: x})

    def init_matrix(self, shape):
        """Return a random-normal tensor of ``shape`` with standard deviation 5.

        Used as the initial value for every weight and bias variable created
        by this class.
        """
        initial_values = tf.random_normal(shape, stddev=5)
        return initial_values

    def init_vector(self, shape):
        """Return an all-zero tensor of ``shape``."""
        zero_values = tf.zeros(shape)
        return zero_values

    def create_recurrent_unit(self, params):
        """Create the LSTM variables and return the step function that uses them.

        Twelve variables are created (input weight, recurrent weight and bias
        for the input, forget, output and candidate gates, in that order),
        stored on ``self`` and appended to ``params`` in creation order.
        Biases are initialised with ``init_matrix`` just like the weights.

        Args:
            params: list that receives the new variables (extended in place).

        Returns:
            ``unit(x, hidden_memory_tm1)``, which maps an input of width
            ``emb_dim`` and a stacked (hidden, cell) state to the next stacked
            (hidden, cell) state.
        """
        input_shape = [self.emb_dim, self.hidden_dim]
        recurrent_shape = [self.hidden_dim, self.hidden_dim]
        bias_shape = [self.hidden_dim]

        def new_gate_variables():
            # Created strictly in (input weight, recurrent weight, bias) order so
            # the variable creation sequence matches across all four gates.
            w = tf.Variable(self.init_matrix(input_shape))
            u = tf.Variable(self.init_matrix(recurrent_shape))
            b = tf.Variable(self.init_matrix(bias_shape))
            return w, u, b

        # Input gate, forget gate, output gate, candidate memory.
        self.Wi, self.Ui, self.bi = new_gate_variables()
        self.Wf, self.Uf, self.bf = new_gate_variables()
        self.Wog, self.Uog, self.bog = new_gate_variables()
        self.Wc, self.Uc, self.bc = new_gate_variables()

        params.extend([
            self.Wi, self.Ui, self.bi,
            self.Wf, self.Uf, self.bf,
            self.Wog, self.Uog, self.bog,
            self.Wc, self.Uc, self.bc,
        ])

        def unit(x, hidden_memory_tm1):
            h_prev, c_prev = tf.unstack(hidden_memory_tm1)

            def pre_activation(w, u, b):
                # x @ W + h_prev @ U + b, shared by all four gates.
                return tf.matmul(x, w) + tf.matmul(h_prev, u) + b

            input_gate = tf.sigmoid(pre_activation(self.Wi, self.Ui, self.bi))
            forget_gate = tf.sigmoid(pre_activation(self.Wf, self.Uf, self.bf))
            output_gate = tf.sigmoid(pre_activation(self.Wog, self.Uog, self.bog))
            candidate = tf.nn.tanh(pre_activation(self.Wc, self.Uc, self.bc))

            # Keep part of the old cell, add the gated candidate.
            c_next = forget_gate * c_prev + input_gate * candidate
            h_next = output_gate * tf.nn.tanh(c_next)

            return tf.stack([h_next, c_next])

        return unit

    def create_output_unit(self, params):
        """Create the output projection and return the function that applies it.

        ``self.Wo`` ([hidden_dim, num_vocabulary]) and ``self.bo``
        ([num_vocabulary]) are both initialised with ``init_matrix`` and
        appended to ``params``.

        Args:
            params: list that receives the two new variables (extended in place).

        Returns:
            ``unit(hidden_memory_tuple)``, which maps a stacked (hidden, cell)
            state to unnormalised vocabulary logits.
        """
        projection_shape = [self.hidden_dim, self.num_vocabulary]
        self.Wo = tf.Variable(self.init_matrix(projection_shape))
        self.bo = tf.Variable(self.init_matrix([self.num_vocabulary]))
        params.extend([self.Wo, self.bo])

        def unit(hidden_memory_tuple):
            # Only the hidden half of the state feeds the projection; the cell
            # half is unpacked and ignored.
            hidden_state, _ = tf.unstack(hidden_memory_tuple)
            return tf.matmul(hidden_state, self.Wo) + self.bo

        return unit

    def g_optimizer(self, *args, **kwargs):
        """Return a new Adam optimizer, forwarding all arguments unchanged."""
        optimizer_cls = tf.train.AdamOptimizer
        return optimizer_cls(*args, **kwargs)

    # Compute the similarity between the validation examples and all embeddings.
    # Rows are L2-normalised first, so without PCA this is the cosine similarity.

    def set_similarity(self, valid_examples=None, pca=True):
        """Build the ops that compare validation tokens against all embeddings.

        Embedding rows are L2-normalised. When ``pca`` is set and the
        vocabulary has at least 20 tokens, the normalised matrix is replaced by
        the product of the first 20 rows of ``u`` (from the SVD of the
        normalised Gram matrix) with the normalised embeddings, leaving only 20
        rows.

        Sets ``self.valid_dataset``, ``self.norm``,
        ``self.normalized_embeddings``, ``self.valid_embeddings`` and
        ``self.similarity``.

        Args:
            valid_examples: token ids to look up; a list or NumPy array. When
                ``None``, defaults to the first 20 ids (capped at the
                vocabulary size) if ``pca`` is set, otherwise to every id in
                the vocabulary. NOTE: on the PCA path only 20 rows remain, so
                explicitly passed ids must be below 20.
            pca: whether to apply the SVD-based reduction described above.
        """
        num_components = 20

        # `is None` rather than `== None`: comparing a NumPy array with `==`
        # is element-wise and its truth value is ambiguous, which raised
        # ValueError whenever a caller passed an array of ids.
        if valid_examples is None:
            if pca:
                # Cap at the vocabulary size so the default ids stay in range
                # when the vocabulary is too small for the PCA branch to run.
                valid_examples = np.arange(min(num_components, self.num_vocabulary))
            else:
                valid_examples = np.arange(self.num_vocabulary)
        self.valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

        # Row-wise L2 norm, kept 2-D so the division broadcasts per row.
        self.norm = tf.sqrt(tf.reduce_sum(tf.square(self.g_embeddings), 1, keep_dims=True))
        self.normalized_embeddings = self.g_embeddings / self.norm

        # PCA
        if pca and self.num_vocabulary >= num_components:
            emb = tf.matmul(self.normalized_embeddings, tf.transpose(self.normalized_embeddings))
            s, u, v = tf.svd(emb)
            # First `num_components` rows of u, all columns.
            u_r = tf.strided_slice(u, begin=[0, 0], end=[num_components, self.num_vocabulary], strides=[1, 1])
            self.normalized_embeddings = tf.matmul(u_r, self.normalized_embeddings)

        self.valid_embeddings = tf.nn.embedding_lookup(
            self.normalized_embeddings, self.valid_dataset)
        self.similarity = tf.matmul(self.valid_embeddings, tf.transpose(self.normalized_embeddings))
